import torch


def _project_onto(v, basis):
    """Return the projection of 1-D tensor *v* onto the span of *basis*.

    *basis* is a list of unit-norm, mutually orthogonal 1-D tensors. An empty
    basis yields the zero vector.
    """
    proj = torch.zeros_like(v)
    for unit in basis:
        proj = proj + torch.dot(v, unit) * unit
    return proj


def _orthonormal_basis(vectors, rel_tol):
    """Build an orthonormal basis of span(*vectors*) via Gram-Schmidt.

    A vector whose residual (after removing its components along the basis
    built so far) is no larger than ``rel_tol`` times its own norm is
    linearly dependent on the earlier vectors (or is zero) and is skipped.
    Normalising such a residual would divide by (nearly) zero and inject
    NaN/noise directions into the basis.
    """
    basis = []
    for v in vectors:
        residual = v - _project_onto(v, basis)
        residual_norm = residual.norm()
        if residual_norm <= rel_tol * v.norm():
            continue
        basis.append(residual / residual_norm)
    return basis


def GradOPS(grads: torch.Tensor, task_num: int, alpha) -> torch.Tensor:
    """Combine per-task gradients with orthogonal projection and re-weighting.

    For every task whose gradient has a negative inner product with at least
    one other task's gradient, the component lying in the span of the other
    tasks' gradients is removed. The (possibly projected) gradients are then
    combined with weights derived from how strongly their sum aligns with
    each original gradient, sharpened by the exponent ``alpha``.

    Args:
        grads: 2-D floating-point tensor with one flattened gradient per row;
            expected to have ``task_num`` rows.
        task_num: Number of tasks (rows of ``grads``) to process.
        alpha: Exponent applied to the normalised alignment ratios; ``0``
            gives uniform weights (plain sum of the projected gradients).

    Returns:
        1-D tensor: the weighted combination of the projected gradients.

    Note:
        Negative alignment ratios combined with a fractional ``alpha``, or an
        alignment mean of zero, can still produce non-finite values.
    """
    finfo = torch.finfo(grads.dtype)
    # Relative threshold below which a Gram-Schmidt residual is treated as
    # numerically zero (i.e. the direction is already spanned).
    rel_tol = 100 * finfo.eps

    proj_grads = grads.clone()
    for i in range(task_num):
        g_i = grads[i]
        # Only the sign of the cosine matters, so the raw dot product is
        # enough; this also avoids 0/0 when a gradient is exactly zero.
        in_conflict = any(
            bool(torch.dot(g_i, grads[j]) < 0)
            for j in range(task_num)
            if j != i
        )
        if not in_conflict:
            continue
        others = [grads[j] for j in range(task_num) if j != i]
        basis = _orthonormal_basis(others, rel_tol)
        proj_grads[i] = g_i - _project_onto(g_i, basis)

    proj_G = proj_grads.sum(0)
    # Scalar projection of the summed gradient onto each original gradient;
    # clamp the norms so an all-zero gradient gives 0 instead of NaN.
    grad_norms = torch.norm(grads, p=2, dim=-1).clamp_min(finfo.tiny)
    L = torch.matmul(proj_G, grads.T) / grad_norms
    r = L / L.mean(0)
    w = torch.pow(r, alpha)
    w = w / w.mean(0)
    return torch.matmul(w, proj_grads)
